import operator
from typing import Annotated
from typing_extensions import TypedDict

from langgraph.types import Send
from langgraph.graph import END, StateGraph, START
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser, JsonOutputParser
from dotenv import load_dotenv

# Load environment variables from a local .env file (presumably the OpenAI
# API key used by ChatOpenAI) before the client below is constructed.
load_dotenv()

# Single chat model shared by every node in this file. temperature=0 asks for
# the least random output the API offers.
model = ChatOpenAI(
    temperature=0,
    model="gpt-4o-mini",
)


# Overall state of the main graph. The caller supplies ``topic`` and
# ``num_subjects``; the graph then fills in ``subjects``, one joke per
# subject in ``jokes``, and finally ``best_selected_joke``.
class OverallState(TypedDict):
    topic: str
    num_subjects: int
    subjects: list[str]
    # ``operator.add`` is the reducer for this key: every ``generate_joke``
    # run returns a one-element list, and we want to combine all of those
    # partial results back into one list instead of overwriting. This is
    # essentially the "reduce" part of map-reduce.
    jokes: Annotated[list[str], operator.add]
    best_selected_joke: str


# Private state of the "map" node: each ``Send`` emitted by
# ``continue_to_jokes`` hands one of these, holding a single subject,
# to ``generate_joke``.
JokeState = TypedDict("JokeState", {"subject": str})


# Generate the subjects the jokes will be about. ``num_subjects`` is only a
# request to the model, so the length of the resulting list can still vary
# from run to run.
def generate_topics(state: OverallState):
    """Ask the model for a list of subjects related to ``state["topic"]``.

    Args:
        state: Graph state; reads ``topic`` and ``num_subjects``.

    Returns:
        A partial state update ``{"subjects": [...]}`` whose value is a list
        of non-empty strings.

    Raises:
        ValueError: If the model's parsed JSON output is not a list.
    """
    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant that generates a list of subjects based on a topic. Your output should be a JSON list of strings, for example: [\"subject1\", \"subject2\"]."),
        ("human", "Generate a list of {num_subjects} subjects related to the topic '{topic}'.")
    ])
    chain = prompt | model | JsonOutputParser()
    parsed = chain.invoke({"topic": state["topic"], "num_subjects": state["num_subjects"]})
    # The prompt asks for a JSON list, but nothing enforces it. Any other JSON
    # value (e.g. a {"subjects": [...]} object or a bare string) would be
    # iterated by ``continue_to_jokes``, yielding dict keys or single
    # characters as subjects, so fail loudly instead.
    if not isinstance(parsed, list):
        raise ValueError(
            f"Expected a JSON list of subjects, got {type(parsed).__name__}: {parsed!r}"
        )
    # ``JokeState.subject`` is a ``str``: coerce items to trimmed strings and
    # drop null/blank entries, which would only produce empty joke requests.
    cleaned = [str(item).strip() for item in parsed if item is not None]
    return {"subjects": [subject for subject in cleaned if subject]}


# Map node: produce one joke for the single subject it is sent.
def generate_joke(state: JokeState):
    """Return one short joke about ``state["subject"]``.

    The joke is wrapped in a one-element list under ``"jokes"`` so the
    ``operator.add`` reducer on ``OverallState.jokes`` can concatenate it
    with the jokes from the other ``generate_joke`` runs.
    """
    joke_chain = (
        ChatPromptTemplate.from_template("Tell me a short joke about {subject}")
        | model
        | StrOutputParser()
    )
    subject = state["subject"]
    return {"jokes": [joke_chain.invoke({"subject": subject})]}


# Fan-out ("map") logic over the generated subjects; registered below as the
# conditional edge leaving ``generate_topics``.
def continue_to_jokes(state: OverallState):
    """Return one ``Send`` per entry in ``state["subjects"]``.

    Each ``Send`` names the node to run (``generate_joke``) together with the
    ``{"subject": ...}`` payload that node receives as its state.
    """
    target_node = "generate_joke"
    return [
        Send(target_node, {"subject": subject})
        for subject in state["subjects"]
    ]


# Here we judge the best joke among those collected in ``jokes``.
def best_joke(state: OverallState):
    """Ask the model to pick the best joke from ``state["jokes"]``.

    Args:
        state: Graph state; reads ``jokes`` (the list combined by the
            ``operator.add`` reducer).

    Returns:
        A partial state update ``{"best_selected_joke": str}``. When there
        are no jokes the value is ``""`` and the model is not called.
    """
    jokes = state.get("jokes") or []
    # With no candidates the model would be asked to choose from nothing and
    # could only invent an answer, so short-circuit instead.
    if not jokes:
        return {"best_selected_joke": ""}
    jokes_str = "\n\n".join(jokes)
    prompt = ChatPromptTemplate.from_template(
        "Here are a few jokes:\n\n{jokes}\n\nWhich one is the best? Just return the joke itself, and nothing else."
    )
    chain = prompt | model | StrOutputParser()
    best = chain.invoke({"jokes": jokes_str})
    return {"best_selected_joke": best}


# Construct the graph: register the three nodes, then wire
# START -> generate_topics -(fan-out)-> generate_joke -> best_joke -> END.
builder = StateGraph(OverallState)
for _node_name, _node_fn in (
    ("generate_topics", generate_topics),
    ("generate_joke", generate_joke),
    ("best_joke", best_joke),
):
    builder.add_node(_node_name, _node_fn)

builder.add_edge(START, "generate_topics")
# ``continue_to_jokes`` returns ``Send`` objects; the list declares the node
# those sends may target.
builder.add_conditional_edges("generate_topics", continue_to_jokes, ["generate_joke"])

for _src, _dst in (("generate_joke", "best_joke"), ("best_joke", END)):
    builder.add_edge(_src, _dst)

graph = builder.compile()

# Call the graph: run it on a sample input and print every streamed step.
# Guarded so that importing this module (e.g. to reuse ``graph``) does not
# trigger model calls; running the file as a script behaves as before.
if __name__ == "__main__":
    for step in graph.stream({"topic": "research", "num_subjects": 3}):
        print(step)